# Copyright 2023 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load(
    "//:build_defs.bzl",
    "xnnpack_benchmark",
    "xnnpack_cxx_library",
    "xnnpack_if_kleidiai_enabled",
    "xnnpack_kleidiai_defines",
    "xnnpack_optional_dnnl_copts",
    "xnnpack_optional_dnnl_deps",
    "xnnpack_optional_gemmlowp_copts",
    "xnnpack_optional_gemmlowp_deps",
    "xnnpack_optional_ruy_copts",
    "xnnpack_optional_ruy_deps",
    "xnnpack_optional_tflite_copts",
    "xnnpack_optional_tflite_deps",
    "xnnpack_visibility",
)
load(
    "//:build_params.bzl",
    "xnnpack_select_if",
)

# Dependencies shared by the micro-kernel and end-to-end benchmark targets in
# this file. Individual targets append their own extras to this list.
MICROKERNEL_BENCHMARK_DEPS = [
    ":bench_utils",
    "//:aligned_allocator",
    "//:all_microkernels",
    "//:common",
    "//:hardware_config",
    "//:math",
    "//:microkernels_h",
    "//:microparams_init",
    "//:microparams",
    "//:packing",
    "//:params",
    "//:xnnpack_h",
]

# Dependencies shared by the operator-level benchmark targets in this file.
# Some targets additionally append xnnpack_optional_tflite_deps().
OPERATOR_BENCHMARK_DEPS = [
    ":bench_utils",
    "//:XNNPACK",
    "//:aligned_allocator",
    "//:cache",
    "//:common",
    "//:math",
]

# Helper library built from utils.cc / utils.h. It is the first entry of both
# shared dependency lists defined earlier in this file.
xnnpack_cxx_library(
    name = "bench_utils",
    srcs = ["utils.cc"],
    hdrs = ["utils.h"],
    # Pass -DXNN_ENABLE_CPUINFO=1 when the //:cpuinfo_enabled condition matches
    # and -DXNN_ENABLE_CPUINFO=0 otherwise.
    copts = select({
        "//:cpuinfo_enabled": ["-DXNN_ENABLE_CPUINFO=1"],
        "//conditions:default": ["-DXNN_ENABLE_CPUINFO=0"],
    }),
    visibility = xnnpack_visibility(),
    # NOTE(review): xnnpack_select_if presumably contributes @cpuinfo only when
    # //:cpuinfo_enabled matches, mirroring copts -- confirm in build_params.bzl.
    deps = [
        "//:allocator",
        "//:common",
        "//:hardware_config",
        "//:memory",
        "//:params",
        "//:xnnpack_h",
        "@com_google_benchmark//:benchmark",
    ] + xnnpack_select_if(
        "//:cpuinfo_enabled",
        ["@cpuinfo"],
    ),
)

# Header-only helper libraries: target `<name>` exposes `<name>.h` and depends
# on the Google Benchmark library.
[cc_library(
    name = helper,
    hdrs = [helper + ".h"],
    deps = [
        "@com_google_benchmark//:benchmark",
    ],
) for helper in [
    "conv",
    "dwconv",
    "spmm",
]]

######################### Benchmarks for micro-kernels #########################

# Shared library built from gemm-benchmark.cc (headers gemm.h and
# gemm-benchmark.h). The GEMM benchmark targets below list it in their deps.
xnnpack_cxx_library(
    name = "gemm_benchmark",
    srcs = [
        "gemm-benchmark.cc",
    ],
    hdrs = [
        "gemm.h",
        "gemm-benchmark.h",
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        "//:config_hdrs",
        "@com_google_benchmark//:benchmark",
    ],
)

# GEMM micro-kernel benchmarks. Each entry `foo_bar` below yields the target
# `foo_bar_bench`, compiled from `foo-bar.cc`.
[xnnpack_benchmark(
    name = gemm_kernel + "_bench",
    srcs = [gemm_kernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        ":gemm_benchmark",
        "//:allocator",
        "//:isa_checks",
    ],
) for gemm_kernel in [
    "bf16_gemm",
    "qd8_f16_qb4w_gemm",
    "qd8_f32_qb4w_gemm",
    "qd8_f16_qc8w_gemm",
    "qd8_f32_qc8w_gemm",
    "qd8_f16_qc4w_gemm",
    "qd8_f32_qc4w_gemm",
    "qs8_qc8w_gemm_fp32",
    "qu8_gemm_fp32",
    "qu8_gemm_rndnu",
    "f16_f32acc_gemm",
    "f16_gemm",
    "f32_qc4w_gemm",
    "f32_qc8w_gemm",
]]

# GEMM micro-kernel benchmarks that additionally take the optional Ruy and
# gemmlowp compiler options and dependencies. Each entry `foo_bar` yields the
# target `foo_bar_bench`, compiled from `foo-bar.cc`.
[xnnpack_benchmark(
    name = gemm_kernel + "_bench",
    srcs = [gemm_kernel.replace("_", "-") + ".cc"],
    copts = xnnpack_optional_ruy_copts() + xnnpack_optional_gemmlowp_copts(),
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        ":gemm_benchmark",
        "//:allocator",
        "//:isa_checks",
    ] + xnnpack_optional_ruy_deps() + xnnpack_optional_gemmlowp_deps(),
) for gemm_kernel in [
    "qs8_gemm",
    "qu8_gemm",
    "f16_gemm_minmax",
    "f32_gemm",
    "f32_gemm_minmax",
    "f32_gemm_goi_minmax",
]]

# f32 bgemm benchmark. bgemm.h is listed in srcs as a private header, and the
# optional Ruy compiler options and dependencies are appended.
xnnpack_benchmark(
    name = "f32_bgemm_bench",
    srcs = [
        "bgemm.h",
        "f32-bgemm.cc",
    ],
    copts = xnnpack_optional_ruy_copts(),
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        "//:allocator",
    ] + xnnpack_optional_ruy_deps(),
)

# qp8-f32-qc4w GEMM benchmark.
# NOTE(review): the KleidiAI defines and the @KleidiAI matmul dependency are
# presumably contributed only when KleidiAI is enabled -- confirm against
# xnnpack_kleidiai_defines / xnnpack_if_kleidiai_enabled in //:build_defs.bzl.
# Unlike the other GEMM benchmarks in this file, "//:allocator" is not listed
# in deps here.
xnnpack_benchmark(
    name = "qp8_f32_qc4w_gemm_bench",
    srcs = [
        "qp8-f32-qc4w-gemm.cc",
    ],
    defines = xnnpack_kleidiai_defines(),
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        ":gemm_benchmark",
        "//:isa_checks",
    ] + xnnpack_if_kleidiai_enabled([
        "@KleidiAI//kai/ukernels/matmul",
    ]),
)

# Conversion (vcvt) micro-kernel benchmarks. Each entry `foo_bar` yields the
# target `foo_bar_bench`, compiled from `foo-bar.cc` together with the shared
# header vcvt-benchmark.h.
[xnnpack_benchmark(
    name = cvt_kernel + "_bench",
    srcs = [
        cvt_kernel.replace("_", "-") + ".cc",
        "vcvt-benchmark.h",
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS,
) for cvt_kernel in [
    "qs8_f16_vcvt",
    "qs8_f32_vcvt",
    "qs8_vcvt",
    "qs16_qs8_vcvt",
    "qu8_f32_vcvt",
    "qu8_vcvt",
    "f16_f32_vcvt",
    "f32_f16_vcvt",
    "f32_qs8_vcvt",
    "f32_qu8_vcvt",
]]

# rsum / rdsum micro-kernel benchmarks. Each entry `foo_bar` yields the target
# `foo_bar_bench`, compiled from `foo-bar.cc` together with the shared header
# rsum-benchmark.h.
[xnnpack_benchmark(
    name = sum_kernel + "_bench",
    srcs = [
        sum_kernel.replace("_", "-") + ".cc",
        "rsum-benchmark.h",
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS,
) for sum_kernel in [
    "qs8_rdsum",
    "qs8_rsum",
    "f16_rsum",
    "f16_f32acc_rsum",
    "f32_rsum",
    "f16_f32acc_rdsum",
    "f32_rdsum",
]]

# Micro-kernel benchmarks that need nothing beyond the shared micro-kernel
# dependencies. Each entry `foo_bar` yields the target `foo_bar_bench`,
# compiled from the single source file `foo-bar.cc`.
[xnnpack_benchmark(
    name = ukernel + "_bench",
    srcs = [ukernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_BENCHMARK_DEPS,
) for ukernel in [
    "qs8_vadd",
    "qs8_vaddc",
    "qs8_vhswish",
    "qs8_vlrelu",
    "qs8_vmul",
    "qs8_vmulc",
    "qu8_vadd",
    "qu8_vaddc",
    "qu8_vhswish",
    "qu8_vlrelu",
    "qu8_vmul",
    "qu8_vmulc",
    "f16_gavgpool_cw",
    "f16_raddstoreexpminusmax",
    "f16_rmax",
    "f16_rminmax",
    "f16_rmin",
    "f32_gavgpool_cw",
    "f32_raddexpminusmax",
    "f32_raddextexp",
    "f32_raddstoreexpminusmax",
    "f32_rmax",
    "f32_rminmax",
    "f32_rmin",
    "f32_vscaleexpminusmax",
    "f32_vscaleextexp",
    "f16_vcmul",
    "f32_vcmul",
    "s16_rmaxabs",
    "s16_window",
    "u32_filterbank_accumulate",
    "u32_filterbank_subtract",
    "u32_vlog",
    "u64_u32_vsqrtshift",
    "i16_vlshift",
    "cs16_vsquareabs",
    "cs16_bfly4",
    "cs16_fftr",
    "x8_lut",
    "xx_transposev",
    "x8_transpose",
    "x16_transpose",
    "x24_transpose",
    "x32_transpose",
    "x64_transpose",
]]

# Requantization benchmarks: `<datatype>_requantization_bench` is compiled from
# `<datatype>-requantization.cc` and additionally links the requantization
# stubs.
[xnnpack_benchmark(
    name = "%s_requantization_bench" % datatype,
    srcs = [
        "%s-requantization.cc" % datatype,
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + ["//:requantization_stubs"],
) for datatype in [
    "qs8",
    "qu8",
]]

# qs8 dwconv micro-kernel benchmark. Unlike the f16/f32 dwconv benchmarks
# further down in this file, it also depends on //:microkernel_configs.
xnnpack_benchmark(
    name = "qs8_dwconv_bench",
    srcs = [
        "qs8-dwconv.cc",
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        ":dwconv",
        "//:indirection",
        "//:microkernel_configs",
        "//:microkernel_utils",
    ],
)

# f16 IGEMM micro-kernel benchmarks. Each entry `foo_bar` yields the target
# `foo_bar_bench`, compiled from `foo-bar.cc`, with the :conv helper and
# //:indirection added to the shared dependencies.
[xnnpack_benchmark(
    name = igemm_kernel + "_bench",
    srcs = [igemm_kernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        ":conv",
        "//:indirection",
    ],
) for igemm_kernel in [
    "f16_f32acc_igemm",
    "f16_igemm",
]]

# f16 unary micro-kernel benchmarks. Each entry `foo_bar` yields the target
# `foo_bar_bench`, compiled from `foo-bar.cc` together with the shared header
# f16-vunary-benchmark.h.
[xnnpack_benchmark(
    name = unary_kernel + "_bench",
    srcs = [
        unary_kernel.replace("_", "-") + ".cc",
        "f16-vunary-benchmark.h",
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS,
) for unary_kernel in [
    "f16_vsigmoid",
    "f16_vtanh",
    "f16_velu",
    "f16_vabs",
    "f16_vhswish",
    "f16_vlrelu",
    "f16_vneg",
    "f16_vrndd",
    "f16_vrndne",
    "f16_vrndu",
    "f16_vrndz",
    "f16_vsqr",
    "f16_vsqrt",
    "f16_vrsqrt",
    "f16_vclamp",
]]

# f32 unary micro-kernel benchmarks. Each entry `foo_bar` yields the target
# `foo_bar_bench`, compiled from `foo-bar.cc` together with the shared header
# f32-vunary-benchmark.h.
[xnnpack_benchmark(
    name = unary_kernel + "_bench",
    srcs = [
        unary_kernel.replace("_", "-") + ".cc",
        "f32-vunary-benchmark.h",
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS,
) for unary_kernel in [
    "f32_vabs",
    "f32_velu",
    "f32_vgelu",
    "f32_vhswish",
    "f32_vlog",
    "f32_vlrelu",
    "f32_vneg",
    "f32_vrelu",
    "f32_vrndd",
    "f32_vrndne",
    "f32_vrndu",
    "f32_vrndz",
    "f32_vsigmoid",
    "f32_vsqr",
    "f32_vsqrt",
    "f32_vrsqrt",
    "f32_vclamp",
    "f32_vtanh",
]]

# f32 IGEMM micro-kernel benchmark. It adds the same extras (:conv and
# //:indirection) as the f16 IGEMM benchmarks earlier in this file.
xnnpack_benchmark(
    name = "f32_igemm_bench",
    srcs = [
        "f32-igemm.cc",
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        ":conv",
        "//:indirection",
    ],
)

# conv-hwc / conv-hwc2chw micro-kernel benchmarks. Each entry `foo_bar` yields
# the target `foo_bar_bench`, compiled from `foo-bar.cc` together with the
# shared header dconv.h.
[xnnpack_benchmark(
    name = conv_kernel + "_bench",
    srcs = [
        "dconv.h",
        conv_kernel.replace("_", "-") + ".cc",
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS,
) for conv_kernel in [
    "f32_conv_hwc",
    "f16_conv_hwc2chw",
    "f32_conv_hwc2chw",
]]

# f16/f32 dwconv micro-kernel benchmarks: `<datatype>_dwconv_bench` is compiled
# from `<datatype>-dwconv.cc`, with the :dwconv helper, //:indirection and
# //:microkernel_utils added to the shared dependencies.
[xnnpack_benchmark(
    name = "%s_dwconv_bench" % datatype,
    srcs = [
        "%s-dwconv.cc" % datatype,
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        ":dwconv",
        "//:indirection",
        "//:microkernel_utils",
    ],
) for datatype in [
    "f16",
    "f32",
]]

# dwconv2d-chw micro-kernel benchmarks: `<datatype>_dwconv2d_chw_bench` is
# compiled from `<datatype>-dwconv2d-chw.cc`, with the :dwconv helper and
# //:indirection added to the shared dependencies.
[xnnpack_benchmark(
    name = "%s_dwconv2d_chw_bench" % datatype,
    srcs = [
        "%s-dwconv2d-chw.cc" % datatype,
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        ":dwconv",
        "//:indirection",
    ],
) for datatype in [
    "f32",
    "f16",
]]

# f32 SpMM micro-kernel benchmark; uses the shared header spmm-benchmark.h and
# depends on the :spmm helper and //:isa_checks.
xnnpack_benchmark(
    name = "f32_spmm_bench",
    srcs = [
        "f32-spmm.cc",
        "spmm-benchmark.h",
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        ":spmm",
        "//:isa_checks",
    ],
)

# f16 SpMM micro-kernel benchmark.
# NOTE(review): unlike f32_spmm_bench, this target lists neither
# spmm-benchmark.h nor //:isa_checks -- confirm f16-spmm.cc needs neither.
xnnpack_benchmark(
    name = "f16_spmm_bench",
    srcs = [
        "f16-spmm.cc",
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [":spmm"],
)

# f32 softmax benchmark. Compiler options and extra dependencies come from the
# xnnpack_optional_dnnl_* helpers loaded from //:build_defs.bzl.
xnnpack_benchmark(
    name = "f32_softmax_bench",
    srcs = [
        "f32-softmax.cc",
    ],
    copts = xnnpack_optional_dnnl_copts(),
    deps = MICROKERNEL_BENCHMARK_DEPS + xnnpack_optional_dnnl_deps(),
)

# f32 im2col + GEMM benchmark; adds the :conv helper and //:im2col to the
# shared micro-kernel dependencies.
xnnpack_benchmark(
    name = "f32_im2col_gemm_bench",
    srcs = [
        "f32-im2col-gemm.cc",
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        ":conv",
        "//:im2col",
    ],
)

# Library built from packq-benchmark.cc with public header packq-benchmark.h;
# bgemm.h is a private header. x8_packq_bench depends on this target.
xnnpack_cxx_library(
    name = "packq_benchmark",
    srcs = [
        "bgemm.h",
        "packq-benchmark.cc",
    ],
    hdrs = ["packq-benchmark.h"],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        "@com_google_benchmark//:benchmark",
    ],
)

# x8 packq benchmark; links the :packq_benchmark library and //:allocator on
# top of the shared micro-kernel dependencies.
xnnpack_benchmark(
    name = "x8_packq_bench",
    srcs = [
        "bgemm.h",
        "x8-packq.cc",
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        ":packq_benchmark",
        "//:allocator",
    ],
)

# packw benchmarks: `<datatype>_packw_bench` is compiled from
# `<datatype>-packw.cc` together with the shared headers bgemm.h and
# packw-benchmark.h.
[xnnpack_benchmark(
    name = "%s_packw_bench" % datatype,
    srcs = [
        "bgemm.h",
        "packw-benchmark.h",
        "%s-packw.cc" % datatype,
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        "//:allocator",
    ],
) for datatype in [
    "x8",
    "qs8",
    "x16",
    "x32",
]]

########################### Benchmarks for operators ###########################

# Unary operator benchmarks. Each entry `foo_bar` yields the target
# `foo_bar_bench`, compiled from `foo-bar.cc` together with the shared header
# unary_operator.h. All of them take the optional TFLite compiler options and
# dependencies and carry the "nowin32" tag.
[xnnpack_benchmark(
    name = unary_op + "_bench",
    srcs = [
        unary_op.replace("_", "-") + ".cc",
        "unary_operator.h",
    ],
    copts = xnnpack_optional_tflite_copts(),
    tags = ["nowin32"],
    deps = OPERATOR_BENCHMARK_DEPS + xnnpack_optional_tflite_deps(),
) for unary_op in [
    "abs",
    "bankers_rounding",
    "ceiling",
    "convert",
    "elu",
    "floor",
    "leaky_relu",
    "negate",
    "hardswish",
    "reciprocal_square_root",
    "sigmoid",
    "softmax",
    "square",
    "square_root",
    "tanh",
    "truncation",
]]

# Operator benchmarks that take the optional TFLite compiler options and
# dependencies and carry the "nowin32" tag. Each entry `foo_bar` yields the
# target `foo_bar_bench`, compiled from `foo-bar.cc`.
[xnnpack_benchmark(
    name = op + "_bench",
    srcs = [op.replace("_", "-") + ".cc"],
    copts = xnnpack_optional_tflite_copts(),
    tags = ["nowin32"],
    deps = OPERATOR_BENCHMARK_DEPS + xnnpack_optional_tflite_deps(),
) for op in [
    "average_pooling",
    "batch_matrix_multiply",
    "convolution",
    "deconvolution",
    "fully_connected",
    "prelu",
    "scaled_dot_product_attention",
]]

# Operator benchmarks that use only the shared operator dependencies, with the
# same `foo_bar` -> `foo_bar_bench` / `foo-bar.cc` naming.
[xnnpack_benchmark(
    name = op + "_bench",
    srcs = [op.replace("_", "-") + ".cc"],
    deps = OPERATOR_BENCHMARK_DEPS,
) for op in [
    "channel_shuffle",
    "global_average_pooling",
    "max_pooling",
]]

############################### E2E benchmarks ###############################

# f16 end-to-end benchmarks: `f16_<family>_e2e_bench` is compiled from
# `f16-<family>-e2e.cc` plus end2end.h and depends on the fp16 MobileNet
# models.
[xnnpack_benchmark(
    name = "f16_%s_e2e_bench" % family,
    srcs = [
        "end2end.h",
        "f16-%s-e2e.cc" % family,
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        "//:XNNPACK",
        "//:microkernel_configs",
        "//:models_h",
        "//models:fp16_mobilenet_v1",
        "//models:fp16_mobilenet_v2",
        "//models:fp16_mobilenet_v3_large",
        "//models:fp16_mobilenet_v3_small",
    ],
) for family in [
    "dwconv",
    "gemm",
]]

# f32 end-to-end benchmarks: `f32_<family>_e2e_bench` is compiled from
# `f32-<family>-e2e.cc` plus end2end.h and depends on the fp32 MobileNet
# models.
[xnnpack_benchmark(
    name = "f32_%s_e2e_bench" % family,
    srcs = [
        "end2end.h",
        "f32-%s-e2e.cc" % family,
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        "//:XNNPACK",
        "//:microkernel_configs",
        "//:models_h",
        "//models:fp32_mobilenet_v1",
        "//models:fp32_mobilenet_v2",
        "//models:fp32_mobilenet_v3_large",
        "//models:fp32_mobilenet_v3_small",
    ],
) for family in [
    "dwconv",
    "gemm",
]]

# qs8 end-to-end benchmarks: `qs8_<family>_e2e_bench` is compiled from
# `qs8-<family>-e2e.cc` plus end2end.h and depends on the qs8 MobileNet v1 and
# v2 models.
[xnnpack_benchmark(
    name = "qs8_%s_e2e_bench" % family,
    srcs = [
        "end2end.h",
        "qs8-%s-e2e.cc" % family,
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        "//:XNNPACK",
        "//:microkernel_configs",
        "//:models_h",
        "//models:qs8_mobilenet_v1",
        "//models:qs8_mobilenet_v2",
    ],
) for family in [
    "dwconv",
    "gemm",
]]

# qu8 end-to-end benchmarks: `qu8_<family>_e2e_bench` is compiled from
# `qu8-<family>-e2e.cc` plus end2end.h and depends on the qu8 MobileNet
# models.
[xnnpack_benchmark(
    name = "qu8_%s_e2e_bench" % family,
    srcs = [
        "end2end.h",
        "qu8-%s-e2e.cc" % family,
    ],
    deps = MICROKERNEL_BENCHMARK_DEPS + [
        "//:XNNPACK",
        "//:microkernel_configs",
        "//:models_h",
        "//models:qu8_mobilenet_v1",
        "//models:qu8_mobilenet_v2",
        "//models:qu8_mobilenet_v3_large",
        "//models:qu8_mobilenet_v3_small",
    ],
) for family in [
    "gemm",
    "dwconv",
]]

# End-to-end benchmark across the MobileNet model targets under //models.
# Model labels have the form `//models:<family>_mobilenet_<variant>`:
#   * fp16, fp16_sparse, fp32, fp32_sparse and qu8 list v1, v2, v3_large and
#     v3_small;
#   * qs8 and qs8_qc8w list v1 and v2 only.
# The comprehensions expand to the labels in the order fp16, fp16_sparse, fp32,
# fp32_sparse, qs8, qs8_qc8w, qu8.
xnnpack_benchmark(
    name = "end2end_bench",
    srcs = ["end2end.cc"],
    deps = [
        ":bench_utils",
        "//:XNNPACK",
        "//:models_h",
    ] + [
        "//models:%s_mobilenet_%s" % (family, variant)
        for family in ["fp16", "fp16_sparse", "fp32", "fp32_sparse"]
        for variant in ["v1", "v2", "v3_large", "v3_small"]
    ] + [
        "//models:%s_mobilenet_%s" % (family, variant)
        for family in ["qs8", "qs8_qc8w"]
        for variant in ["v1", "v2"]
    ] + [
        "//models:qu8_mobilenet_%s" % variant
        for variant in ["v1", "v2", "v3_large", "v3_small"]
    ],
)
